{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import diffusers\n",
    "from diffusers import UNet2DConditionModel\n",
    "from PIL import Image\n",
    "import sys\n",
    "from transformers import CLIPVisionModel, CLIPImageProcessor\n",
    "import torch\n",
    "\n",
    "sys.path.append(\"/home/aihao/workspace/StableDiffusionReferenceOnly/src\")\n",
    "from stable_diffusion_reference_only.pipelines.stable_diffusion_reference_only_pipeline import (\n",
    "    StableDiffusionReferenceOnlyPipeline,\n",
    ")\n",
    "from stable_diffusion_reference_only.models.unet_2d_dobule_condition import (\n",
    "    UNet2DDobuleConditionModel,\n",
    ")\n",
    "from diffusers.schedulers import DDPMScheduler\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Active configuration: Stable Diffusion 2.1 UNet weights plus a CLIP image encoder.\n",
    "pretrained_unet_path = \"stabilityai/stable-diffusion-2-1\"\n",
    "pretrained_image_encoder_path = \"openai/clip-vit-large-patch14\"\n",
    "unet_config_path = \"/home/aihao/workspace/StableDiffusionReferenceOnly/src/stable_diffusion_reference_only/models/unet-2-1.json\"\n",
    "\n",
    "# Alternative configuration (SDXL base 1.0): uncomment the three lines below\n",
    "# and comment out the three assignments above to switch.\n",
    "# unet_config_path = \"/home/aihao/workspace/StableDiffusionReferenceOnly/src/stable_diffusion_reference_only/models/unet_xl-base-1.0.json\"\n",
    "# pretrained_unet_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
    "# pretrained_image_encoder_path = \"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the UNet config JSON into a dict. The encoding is passed explicitly so\n",
    "# the read does not depend on the platform's default locale encoding.\n",
    "with open(unet_config_path, encoding=\"utf-8\") as f:\n",
    "    unet_config = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the parsed config as the cell output for inspection.\n",
    "unet_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pretrained UNet from the \"unet\" subfolder of the configured checkpoint.\n",
    "pretrained_unet = UNet2DConditionModel.from_pretrained(pretrained_unet_path, subfolder=\"unet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the double-condition UNet from the pretrained UNet's own config.\n",
    "# NOTE(review): the dict parsed from unet_config_path (unet_config) is not used\n",
    "# here or anywhere else in this notebook - confirm that building from\n",
    "# pretrained_unet.config rather than unet_config is intended.\n",
    "base_unet_config = pretrained_unet.config\n",
    "unet = UNet2DDobuleConditionModel.from_config(base_unet_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the newly built UNet as the cell output.\n",
    "unet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fit_to_shape(source, target_shape):\n",
    "    \"\"\"Crop and/or zero-pad `source` so that its shape equals `target_shape`.\n",
    "\n",
    "    Every dimension is treated independently: one that is too long keeps its\n",
    "    leading entries, one that is too short is zero-padded at its end.\n",
    "\n",
    "    Args:\n",
    "        source: Tensor to adapt.\n",
    "        target_shape: Desired shape, with as many dimensions as `source`.\n",
    "\n",
    "    Returns:\n",
    "        A tensor whose shape equals `target_shape`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `source` and `target_shape` differ in number of dimensions.\n",
    "    \"\"\"\n",
    "    if source.dim() != len(target_shape):\n",
    "        raise ValueError(\n",
    "            f\"cannot fit a {source.dim()}-D tensor to a {len(target_shape)}-D shape\"\n",
    "        )\n",
    "    # Keep only the leading slice of every dimension that is too long.\n",
    "    cropped = source[\n",
    "        tuple(slice(0, min(have, want)) for have, want in zip(source.shape, target_shape))\n",
    "    ]\n",
    "    # F.pad takes (left, right) pairs ordered from the LAST dimension backwards,\n",
    "    # so the shapes are walked in reverse.\n",
    "    padding = []\n",
    "    for have, want in zip(reversed(cropped.shape), reversed(target_shape)):\n",
    "        padding.extend((0, want - have))\n",
    "    if not any(padding):\n",
    "        return cropped\n",
    "    return torch.nn.functional.pad(cropped, padding)\n",
    "\n",
    "\n",
    "# Seed the new UNet's state dict with the pretrained weights.\n",
    "# Shapes are compared per dimension: torch.Size is a tuple, so ordering it\n",
    "# with `<` would be lexicographic and would misclassify mixed cases.\n",
    "unet_parameters = unet.state_dict()\n",
    "pretrained_unet_parameters = pretrained_unet.state_dict()\n",
    "for key in unet_parameters:\n",
    "    if key not in pretrained_unet_parameters:\n",
    "        # No pretrained counterpart: keep the fresh initialization.\n",
    "        continue\n",
    "    source = pretrained_unet_parameters[key]\n",
    "    target_shape = unet_parameters[key].shape\n",
    "    if source.shape == target_shape:\n",
    "        unet_parameters[key] = source\n",
    "    elif source.dim() == len(target_shape):\n",
    "        # Same rank, different size: report the key and resize the pretrained tensor.\n",
    "        print(key)\n",
    "        unet_parameters[key] = _fit_to_shape(source, target_shape)\n",
    "    else:\n",
    "        # Different rank: there is no sensible crop/pad, so keep the fresh initialization.\n",
    "        print(\n",
    "            f\"{key}: rank mismatch {tuple(source.shape)} -> {tuple(target_shape)}, \"\n",
    "            \"keeping fresh initialization\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the merged tensors into the new UNet and show what load_state_dict returned.\n",
    "load_result = unet.load_state_dict(unet_parameters)\n",
    "load_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check, key by key, which tensors of the new UNet match the pretrained UNet.\n",
    "# Both state dicts are built once instead of on every loop iteration.\n",
    "# torch.equal gives one bool per key and reports tensors of different shapes\n",
    "# as unequal, where an elementwise `==` would print whole boolean tensors and\n",
    "# raise on shapes that cannot be broadcast.\n",
    "new_state = unet.state_dict()\n",
    "pretrained_state = pretrained_unet.state_dict()\n",
    "for key, tensor in new_state.items():\n",
    "    if key in pretrained_state:\n",
    "        print(key, torch.equal(tensor, pretrained_state[key]))\n",
    "    else:\n",
    "        # Keys that exist only in the new UNet are printed on their own.\n",
    "        print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the VAE from the \"vae\" subfolder of the Stable Diffusion 2.1 repository.\n",
    "# Alternative kept for reference:\n",
    "# vae = diffusers.AutoencoderKL.from_pretrained(\"stabilityai/sdxl-vae\")\n",
    "vae = diffusers.AutoencoderKL.from_pretrained(\"stabilityai/stable-diffusion-2-1\", subfolder=\"vae\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the CLIP vision model from the configured image-encoder checkpoint.\n",
    "image_encoder = CLIPVisionModel.from_pretrained(\n",
    "    pretrained_image_encoder_path\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image processor from the same checkpoint as the image encoder.\n",
    "clip_image_processor = CLIPImageProcessor.from_pretrained(\n",
    "    pretrained_image_encoder_path\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the DDPM scheduler from the \"scheduler\" subfolder of the pretrained UNet's checkpoint.\n",
    "scheduler = DDPMScheduler.from_pretrained(\n",
    "    pretrained_unet_path,\n",
    "    subfolder=\"scheduler\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the pipeline from the components created in the cells above.\n",
    "# The arguments are positional - presumably the order must match the\n",
    "# pipeline constructor's signature; verify against the pipeline class.\n",
    "pipe = StableDiffusionReferenceOnlyPipeline(\n",
    "    vae,\n",
    "    image_encoder,\n",
    "    clip_image_processor,\n",
    "    unet,\n",
    "    scheduler,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the assembled pipeline to the output directory.\n",
    "init_pipeline_dir = \"/home/aihao/workspace/DeepLearningContent/models/sd_reference_only/init_0.1\"\n",
    "pipe.save_pretrained(init_pipeline_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
